# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import string, re
import sympy as sm
import numpy as np
import itertools as it
from hysop import __KERNEL_DEBUG__, vprint, dprint
from hysop.backend.device.opencl import cl, clArray, clTypes
from hysop.tools.numerics import MPZ, MPQ, MPFR, F2Q
from hysop.tools.htypes import first_not_None, to_tuple
# Component counts of OpenCL C builtin vector types (1 means scalar).
vsizes = [1, 2, 3, 4, 8, 16]
# Categories of builtin base types.
base_types = ["float", "signed", "unsigned"]
float_base_types = ["half", "float", "double"]
signed_base_types = ["char", "short", "int", "long"]
unsigned_base_types = ["uchar", "ushort", "uint", "ulong"]


def _builtin_type_names(scalar_types):
    """
    Return the OpenCL builtin type names obtained by combining each scalar
    type of scalar_types with every component count of vsizes, ordered by
    scalar type then by component count (e.g. "char", "char2", ..., "char16").

    The scalar (1 component) "half" entry is omitted: only its vector forms
    ("half2", ..., "half16") are listed.
    """
    return [
        scalar if n == 1 else f"{scalar}{n}"
        for scalar, n in it.product(scalar_types, vsizes)
        if not (n == 1 and scalar == "half")
    ]


float_types = _builtin_type_names(float_base_types)
signed_types = _builtin_type_names(signed_base_types)
unsigned_types = _builtin_type_names(unsigned_base_types)
integer_types = signed_types + unsigned_types
builtin_types = integer_types + float_types

# OpenCL extension required to use each floating point base type (None = core).
float_base_type_require = {
    "half": "cl_khr_fp16",
    "float": None,
    "double": "cl_khr_fp64",
}
# Decimal digits of precision of each floating point base type.
FLT_DIG = {
    "half": 3,  # = HALF_DIG
    "float": 6,  # = FLT_DIG
    "double": 15,  # = DBL_DIG
}
# Mantissa bits (including the implicit leading bit) of each floating point base type.
FLT_MANT_DIG = {
    "half": 11,  # = HALF_MANT_DIG
    "float": 24,  # = FLT_MANT_DIG
    "double": 53,  # = DBL_MANT_DIG
}
# Literal suffix appended to floating point constants of each base type.
FLT_LITERAL = {"half": "h", "float": "f", "double": ""}
# Size in bytes of each floating point base type.
FLT_BYTES = {"half": 2, "float": 4, "double": 8}
[docs]
def basetype(fulltype):
    """
    Return fulltype with every ASCII digit removed, i.e. the base type name
    of an OpenCL vector type ("float4" -> "float", "uint" -> "uint").
    """
    kept = (char for char in fulltype if char not in string.digits)
    return "".join(kept)
[docs]
def components(fulltype):
    """
    Return the component count of an OpenCL type name, read from the
    characters left after removing ASCII letters and underscores
    ("float4" -> 4). Names without such characters ("double") count as 1.
    """
    dropped = string.ascii_letters + "_"
    remainder = "".join(char for char in fulltype if char not in dropped)
    if not remainder:
        return 1
    return int(remainder)
[docs]
def mangle_vtype(fulltype):
    """
    Return a short mangled name for an OpenCL type: the first letter of its
    base type followed by its component count ("float4" -> "f4", "uint" -> "u1").
    """
    return f"{basetype(fulltype)[0]}{components(fulltype)}"
[docs]
def vtype(basetype, N):
    """
    Return the OpenCL type name with base type basetype and N components
    ("float", 4 -> "float4"); N == 1 yields the scalar name unchanged.
    """
    if N == 1:
        return basetype
    return f"{basetype}{N}"
[docs]
def itype(fulltype):
    """
    Return the int type name with as many components as fulltype
    ("float4" -> "int4", "double" -> "int").
    """
    return vtype("int", components(fulltype))
[docs]
def uitype(fulltype):
    """
    Return the uint type name with as many components as fulltype
    ("float4" -> "uint4", "double" -> "uint").
    """
    return vtype("uint", components(fulltype))
[docs]
def np_dtype(fulltype):
    """
    Return the numpy dtype that pyopencl associates with the OpenCL C type
    name fulltype, as given by cl.tools.get_or_register_dtype(fulltype).
    """
    lookup = cl.tools.get_or_register_dtype
    return lookup(fulltype)
[docs]
def vtype_component_adressing(i, mode="hex"):
    """
    Return the character naming component i of an OpenCL vector.

    mode selects the naming scheme: "hex" (0-9a-f), "HEX" (0-9A-F) or
    "pos" (x, y, z, w).

    Raises
    ------
    ValueError
        If mode is not one of the supported schemes.
    """
    schemes = (
        ("hex", "0123456789abcdef"),
        ("HEX", "0123456789ABCDEF"),
        ("pos", "xyzw"),
    )
    for name, alphabet in schemes:
        if mode == name:
            return alphabet[i]
    raise ValueError("Bad vtype component adressing mode!")
[docs]
def vtype_access(i, N, mode="hex"):
    """
    Return the member access suffix for component i of an N-component value.

    Scalars (N == 1) need no suffix. Hexadecimal modes ("hex"/"HEX") are
    prefixed with "s" (e.g. "s3"); "pos" mode yields "x", "y", ...
    """
    assert i < N
    if N == 1:
        return ""
    prefix = "s" if mode.lower() == "hex" else ""
    return prefix + vtype_component_adressing(i, mode)
[docs]
def float_to_hex_str(f, fbtype):
    """
    Format the scalar f as an OpenCL C hexadecimal floating point literal.

    The hexadecimal mantissa given by float.hex() is truncated to the
    FLT_MANT_DIG[fbtype] bits of the target type (rounded up to whole hex
    digits), the sign is always made explicit and the FLT_LITERAL[fbtype]
    suffix is appended, e.g. 1.5 -> "+0x1.800000p+0f" for "float".

    Parameters
    ----------
    f: scalar accepted by float() (python float, numpy floating, MPFR, ...)
    fbtype: str, one of "half", "float" or "double"

    Returns
    -------
    str
        The literal, "NAN" for NaN, and "INFINITY" / "-INFINITY" for
        infinities (float.hex() has no hexadecimal form for them).
    """
    if f != f:
        return "NAN"
    value = float(f)
    if abs(value) == float("inf"):
        return "INFINITY" if value > 0 else "-INFINITY"
    # float.hex() returns "[-]0x<h.hhhh>p<exp>"
    sign, _, body = value.hex().rpartition("0x")
    mantissa, _, exponent = body.partition("p")
    mant_dig = FLT_MANT_DIG[fbtype]
    literal = FLT_LITERAL[fbtype]
    # Hex digits needed for the fractional mantissa bits,
    # +2 = leading one or zero and decimal point characters (1.abde... or 0.abcde...)
    nhex = (mant_dig - 1 + 3) // 4 + 2
    return f"{sign or '+'}0x{mantissa[:nhex]}p{exponent}{literal}"
[docs]
def float_to_dec_str(f, fbtype):
    """
    Format the scalar f as an OpenCL C decimal floating point literal.

    The digits come from repr(float(f)) (shortest round-trip representation),
    truncated (not rounded) to about FLT_DIG[fbtype] + 1 significant digits,
    followed by the FLT_LITERAL[fbtype] suffix ("h", "f", or "" for double).
    Leading zeros of the fractional part of numbers smaller than one are not
    counted as significant digits.

    Parameters
    ----------
    f: scalar accepted by float() (python float, numpy floating, MPFR, ...)
    fbtype: str, one of "half", "float" or "double"

    Returns
    -------
    str
        e.g. "1.5f", "-0.0001234567f" or "1e-05f" for fbtype "float";
        "NAN" for NaN and "INFINITY" / "-INFINITY" for infinities.
    """
    if f != f:
        return "NAN"
    value = float(f)
    if abs(value) == float("inf"):
        # repr() would give "inf", which is not an OpenCL C token.
        return "INFINITY" if value > 0 else "-INFINITY"
    dig = FLT_DIG[fbtype]
    literal = FLT_LITERAL[fbtype]
    text = repr(value)
    if "." not in text:
        # repr() only omits the decimal point in exponent notation (e.g. "1e-05"),
        # which is already a valid literal once the type suffix is appended.
        return text + literal
    integral, fractional = text.split(".", 1)
    mantissa, _, exponent = fractional.partition("e")
    # significant digits already consumed by the integral part
    sig = len(integral.lstrip("+-").lstrip("0"))
    # for |value| < 1 the fractional leading zeros are not significant
    leading = len(mantissa) - len(mantissa.lstrip("0")) if sig == 0 else 0
    mantissa = mantissa[: leading + dig - sig + 1]
    exponent = f"e{exponent}" if exponent else ""
    return f"{integral}.{mantissa}{exponent}{literal}"
# pyopencl specifics: `vec` aliases the pyopencl vector types module (clTypes),
# which provides the vector dtypes (int2, float4, ...) and the make_* vector
# constructors referenced by the tables below.
vec = clTypes
[docs]
def npmake(dtype):
    """
    Return a one-argument constructor that converts a scalar with dtype,
    used as the scalar counterpart of the pyopencl make_* vector constructors.
    """

    def _make(scalar):
        return dtype(scalar)

    return _make
# Per-family type tables, indexed like `vsizes` (1, 2, 3, 4, 8, 16 components):
# entry 0 is the numpy scalar type, the others are the pyopencl vector dtypes.
vtype_int = [np.int32] + [getattr(vec, f"int{n}") for n in vsizes[1:]]
vtype_uint = [np.uint32] + [getattr(vec, f"uint{n}") for n in vsizes[1:]]
vtype_simple = [np.float32] + [getattr(vec, f"float{n}") for n in vsizes[1:]]
vtype_double = [np.float64] + [getattr(vec, f"double{n}") for n in vsizes[1:]]
cl_vec_types = vtype_int + vtype_uint + vtype_simple + vtype_double

# Matching value constructors, indexed like `vsizes`: entry 0 casts a scalar
# with numpy, the others are the pyopencl make_<type><n> functions.
make_int = [npmake(np.int32)] + [getattr(vec, f"make_int{n}") for n in vsizes[1:]]
make_uint = [npmake(np.uint32)] + [
    getattr(vec, f"make_uint{n}") for n in vsizes[1:]
]
make_simple = [npmake(np.float32)] + [
    getattr(vec, f"make_float{n}") for n in vsizes[1:]
]
make_double = [npmake(np.float64)] + [
    getattr(vec, f"make_double{n}") for n in vsizes[1:]
]
[docs]
def simplen(n):
    """Return the single precision type with n components (np.float32 for n == 1)."""
    return np.float32 if n == 1 else vtype_simple[vsizes.index(n)]
[docs]
def doublen(n):
    """Return the double precision type with n components (np.float64 for n == 1)."""
    return np.float64 if n == 1 else vtype_double[vsizes.index(n)]
[docs]
def intn(n):
    """Return the 32-bit signed integer type with n components (np.int32 for n == 1)."""
    return np.int32 if n == 1 else vtype_int[vsizes.index(n)]
[docs]
def uintn(n):
    """Return the 32-bit unsigned integer type with n components (np.uint32 for n == 1)."""
    return np.uint32 if n == 1 else vtype_uint[vsizes.index(n)]
# Base type name -> function returning the type with n components.
_typen = dict(
    float=simplen,
    simple=simplen,
    double=doublen,
    int=intn,
    uint=uintn,
)
[docs]
def typen(btype, n):
    """
    Return the numpy/pyopencl type with base type btype ("float", "simple",
    "double", "int" or "uint") and n components.
    """
    factory = _typen[btype]
    return factory(n)
[docs]
def make_simplen(vals, n, dval=0):
    """
    Build a single precision value with n components from vals (scalar or
    sequence, converted with to_tuple), padding missing trailing components
    with dval.
    """
    values = to_tuple(vals)
    padded = values + (dval,) * (n - len(values))
    maker = make_simple[vsizes.index(n)]
    return maker(*padded)
[docs]
def make_doublen(vals, n, dval=0):
    """
    Build a double precision value with n components from vals (scalar or
    sequence, converted with to_tuple), padding missing trailing components
    with dval.
    """
    values = to_tuple(vals)
    padded = values + (dval,) * (n - len(values))
    maker = make_double[vsizes.index(n)]
    return maker(*padded)
[docs]
def make_intn(vals, n, dval=0):
    """
    Build a 32-bit signed integer value with n components from vals (scalar
    or sequence, converted with to_tuple), padding missing trailing
    components with dval.
    """
    values = to_tuple(vals)
    padded = values + (dval,) * (n - len(values))
    maker = make_int[vsizes.index(n)]
    return maker(*padded)
[docs]
def make_uintn(vals, n, dval=0):
    """
    Build a 32-bit unsigned integer value with n components from vals
    (scalar or sequence, converted with to_tuple), padding missing trailing
    components with dval.
    """
    values = to_tuple(vals)
    padded = values + (dval,) * (n - len(values))
    maker = make_uint[vsizes.index(n)]
    return maker(*padded)
# Base type name -> function building a value with n components.
_make_typen = dict(
    float=make_simplen,
    simple=make_simplen,
    double=make_doublen,
    int=make_intn,
    uint=make_uintn,
)
[docs]
def make_typen(btype):
    """
    Return the value builder (make_simplen, make_doublen, make_intn or
    make_uintn) associated with base type btype.
    """
    maker = _make_typen[btype]
    return maker
[docs]
def cl_type_to_dtype(cl_type):
    """
    Return the numpy/pyopencl type matching the OpenCL type name cl_type
    (e.g. "float4" -> the pyopencl float4 dtype, "int" -> np.int32).
    """
    return typen(basetype(cl_type), components(cl_type))
[docs]
def cl_vec_type_to_scalar_and_count(cl_vec_type):
    """
    Split one of the cl_vec_types entries into (numpy scalar type, component count).

    Raises
    ------
    RuntimeError
        If cl_vec_type belongs to cl_vec_types but to none of the vtype_*
        tables (inconsistent tables).
    """
    assert cl_vec_type in cl_vec_types
    for family in (vtype_int, vtype_uint, vtype_simple, vtype_double):
        if cl_vec_type in family:
            return family[0], vsizes[family.index(cl_vec_type)]
    raise RuntimeError("cl_vec_types != U(vtype_*)")
[docs]
class TypeGen:
    """
    Render scalar values (python, numpy, MPZ/MPQ/MPFR and sympy numbers) as
    OpenCL C source literals for a given floating point base type.
    """

    def __init__(self, fbtype="float", float_dump_mode="dec"):
        """
        Parameters
        ----------
        fbtype: str
            Floating point base type used for float literals.
        float_dump_mode: str
            "dec"/"decimal" to print floats with float_to_dec_str,
            "hex"/"hexadecimal" to print them with float_to_hex_str.

        Raises
        ------
        ValueError
            If float_dump_mode is not one of the supported modes.
        """
        self.float_base_types = float_base_types
        self.FLT_BYTES = FLT_BYTES
        self.FLT_DIG = FLT_DIG
        self.FLT_MANT_DIG = FLT_MANT_DIG
        self.FLT_LITERAL = FLT_LITERAL
        self.np_dtype = np_dtype
        self.float_to_dec_str = float_to_dec_str
        self.float_to_hex_str = float_to_hex_str
        self.fbtype = fbtype
        self.float_dump_mode = float_dump_mode
        if float_dump_mode in ["hex", "hexadecimal"]:
            self.float_to_str = float_to_hex_str
        elif float_dump_mode in ["dec", "decimal"]:
            self.float_to_str = float_to_dec_str
        else:
            raise ValueError(f"Unknown float_dump_mode '{float_dump_mode}'")

    @staticmethod
    def _as_scalar(val):
        """
        Unwrap single element lists, tuples and numpy arrays; return any other
        non-container value unchanged.

        Raises
        ------
        ValueError
            If val is a dict, or a list/tuple/array holding more than one value.
        """
        if isinstance(val, (list, tuple, dict, np.ndarray)):
            if isinstance(val, (list, tuple)) and len(val) == 1:
                return val[0]
            if isinstance(val, np.ndarray) and val.size == 1:
                return val.flatten()[0]
            raise ValueError(f"Value is not a scalar, got {val}.")
        return val

    def dump(self, val):
        """
        Return the OpenCL C source representation of the scalar val.

        - bools: "true" / "false";
        - floats (float, numpy floating, MPFR, sympy Float): self.float_to_str
          output wrapped in parentheses;
        - integers (int, numpy integer, MPZ, sympy Integer): "0" + suffix for
          zero, "(<sign><digits><suffix>)" otherwise, where the suffix holds
          "u" for numpy unsigned integers and "L" for np.int64/np.uint64/MPZ;
        - rationals (MPQ, sympy Rational): dumped as a float unless
          __KERNEL_DEBUG__ is set, in which case the exact quotient
          "(p.0<lit>/q.0<lit>)" is emitted (or just p when q == 1);
        - strings are returned unchanged; other types go through str().

        Single element lists/tuples/arrays are unwrapped first.

        Raises
        ------
        ValueError
            If val is a dict or a container with more than one element.
        """
        val = self._as_scalar(val)
        # bool is a subclass of int: test it first so that True/False are not
        # dumped as the integers "(+True)" / "0".
        if isinstance(val, (bool, np.bool_)):
            return "true" if val else "false"
        if isinstance(val, (float, np.floating, MPFR, sm.Float)):
            sval = self.float_to_str(val, self.fbtype)
            return f"({sval})"
        if isinstance(val, (np.integer, int, MPZ, sm.Integer)):
            suffix = ""
            if isinstance(val, np.unsignedinteger):
                suffix += "u"
            if isinstance(val, (np.int64, np.uint64, MPZ)):
                suffix += "L"
            if val == 0:
                return "0" + suffix
            sign = "+" if val > 0 else "-"
            sval = str(val)
            if val < 0:
                # drop the '-' already carried by sign
                sval = sval[1:]
            return f"({sign}{sval}{suffix})"
        if isinstance(val, (MPQ, sm.Rational)):
            if not __KERNEL_DEBUG__:
                return self.dump(float(val))
            if isinstance(val, MPQ):
                num, den = val.numerator, val.denominator
            else:
                num, den = val.p, val.q
            if den == 1:
                return str(num)
            return "({}.0{f}/{}.0{f})".format(num, den, f=FLT_LITERAL[self.fbtype])
        if isinstance(val, str):
            return val
        # Fallback for any other type.
        return str(val)

    def dumped_type(self, val):
        """
        Return the OpenCL C type name associated with the scalar val, or None
        for unsupported types (strings included).

        bools map to "bool"; floats and rationals to self.fbtype; np.int64 and
        MPZ to "long", np.uint64 to "ulong", other numpy unsigned integers to
        "uint", python ints to "long" and remaining integers to "int".
        Single element lists/tuples/arrays are unwrapped first.

        Raises
        ------
        ValueError
            If val is a dict or a container with more than one element.
        """
        val = self._as_scalar(val)
        # bool is a subclass of int: test it before the integer types.
        if isinstance(val, (bool, np.bool_)):
            return "bool"
        if isinstance(val, (float, np.floating, MPFR, sm.Float)):
            return self.fbtype
        if isinstance(val, (np.integer, int, MPZ, sm.Integer)):
            if isinstance(val, (np.int64, MPZ)):
                return "long"
            if isinstance(val, np.uint64):
                return "ulong"
            if isinstance(val, np.unsignedinteger):
                return "uint"
            if isinstance(val, int):
                return "long"
            return "int"
        if isinstance(val, (MPQ, sm.Rational)):
            return self.fbtype
        return None
# Struct type generation (type sizes and struct field offsets) differs from one
# device to another, depending on the architecture, the compiler implementation
# and its features.
# /!\ Do not share one OpenClTypeGen instance between two devices that are not
# equivalent.
[docs]
class OpenClTypeGen(TypeGen):
    """
    OpenCL specialization of TypeGen, bound to a (device, context, platform).

    It exposes the module level OpenCL type tables and helpers as attributes,
    selects the numpy/pyopencl types matching the floating point base type
    fbtype, and computes the program build options this precision requires.
    """

    @staticmethod
    def devicelessTypegen():
        """
        Sometimes we do not need structs and code generation is device independent.

        Return a default OpenClTypeGen (fbtype="float") with no device,
        context nor platform attached.
        """
        return OpenClTypeGen(device=None, context=None, platform=None)

    def __init__(
        self,
        device,
        context,
        platform,
        fbtype="float",
        float_dump_mode="dec",
        use_short_circuit_ops=False,
        unroll_loops=False,
    ):
        """
        Parameters
        ----------
        device, context, platform:
            OpenCL objects this generator is bound to (devicelessTypegen
            passes None for all three).
        fbtype: str
            Floating point base type: "float", "double" or "half".
        float_dump_mode: str
            Float literal format, see TypeGen.
        use_short_circuit_ops, unroll_loops: bool
            Code generation flags stored on the instance (also part of __repr__).

        Raises
        ------
        ValueError
            If fbtype or float_dump_mode is unknown, or if the device does not
            support the requested precision (see get_precision_opts).
        """
        super().__init__(fbtype, float_dump_mode)
        self.device = device
        self.context = context
        self.platform = platform
        self.use_short_circuit_ops = use_short_circuit_ops
        self.unroll_loops = unroll_loops
        self.vsizes = vsizes
        self.signed_base_types = signed_base_types
        self.unsigned_base_types = unsigned_base_types
        self.integer_base_types = signed_base_types + unsigned_base_types
        self.float_types = float_types
        self.signed_types = signed_types
        self.unsigned_types = unsigned_types
        self.integer_types = integer_types
        self.builtin_types = builtin_types
        self.float_base_type_require = float_base_type_require
        self.basetype = basetype
        self.components = components
        self.vtype = vtype
        self.itype = itype
        self.uitype = uitype
        self.np_dtype = np_dtype
        self.vtype_component_adressing = vtype_component_adressing
        self.vtype_access = vtype_access
        self.mangle_vtype = mangle_vtype
        self.float_to_dec_str = float_to_dec_str
        self.float_to_hex_str = float_to_hex_str
        # pyopencl specifics
        self.intn = intn
        self.uintn = uintn
        self.simplen = simplen
        self.doublen = doublen
        self.typen = typen
        self.make_intn = make_intn
        self.make_uintn = make_uintn
        self.make_simplen = make_simplen
        self.make_doublen = make_doublen
        self.make_typen = make_typen
        if fbtype == "float":
            self.floatn = simplen
            self.make_floatn = make_simplen
            self.dtype = np.float32
        elif fbtype == "double":
            self.floatn = doublen
            self.make_floatn = make_doublen
            self.dtype = np.float64
        elif fbtype == "half":
            # NOTE(review): halfn / make_halfn are not defined in the visible part
            # of this module; confirm they exist, otherwise this branch raises NameError.
            self.floatn = halfn
            self.make_floatn = make_halfn
            self.dtype = np.float16
        else:
            raise ValueError(f"Unknown fbtype '{fbtype}'")
        self._ftype_build_options = self.get_precision_opts()

    def ftype_build_options(self):
        """Return the build options computed by get_precision_opts() at construction."""
        return self._ftype_build_options

    def device_has_ftype(self, device):
        """
        Return True if device can use self.fbtype: either fbtype requires no
        extension, or the required extension name appears in device.extensions.
        """
        req = self.float_base_type_require[self.fbtype]
        if req is None:
            return True
        # Compare whole extension names (req is a string such as "cl_khr_fp64");
        # split() also tolerates repeated or trailing whitespace.
        return req in device.extensions.split()

    def cl_requirements(self):
        """Return a one element list holding the extension required by fbtype (None if none)."""
        return [self.float_base_type_require[self.fbtype]]

    def opencl_version_greater(self, major, minor):
        """Return True if the device OpenCL version is strictly greater than major.minor."""
        return self.opencl_version() > (major, minor)

    def opencl_version(self):
        """
        Return the device OpenCL version as a (major, minor) tuple of ints,
        parsed from the device version string ("OpenCL <major>.<minor> ...").

        Raises
        ------
        RuntimeError
            If the version string does not match the expected pattern.
        """
        assert self.device is not None
        sversion = self.device.version.strip()
        _regexp = r"OpenCL\s+(\d+)\.(\d+)"
        match = re.match(_regexp, sversion)
        if not match:
            msg = "Could not extract OpenCL version from device returned version '{}' "
            msg += "and regular expression '{}'."
            msg = msg.format(sversion, _regexp)
            raise RuntimeError(msg)
        return (int(match.group(1)), int(match.group(2)))

    def dtype_from_str(self, stype):
        """
        Return the numpy/pyopencl type for the OpenCL type name stype, where
        the placeholders "ftype" and "fbtype" are replaced by self.fbtype
        (e.g. "ftype4" -> float4 dtype when fbtype is "float").
        """
        stype = stype.replace("ftype", self.fbtype).replace("fbtype", self.fbtype)
        btype = basetype(stype)
        N = components(stype)
        return typen(btype, N)

    def dump_expr(self, expr, symbol2vars=None, **printer_settings):
        """
        Print sympy expression expr as OpenCL code.
        Sympy symbols may be replaced using symbol2vars dictionary.
        This dumper uses OpenClTypeGen.dump for floats and quotients.
        See hysop.backend.device.opencl.opencl_printer.OpenClPrinter
        """
        from hysop.backend.device.opencl.opencl_printer import OpenClPrinter

        printer = OpenClPrinter(
            typegen=self, symbol2vars=symbol2vars, **printer_settings
        )
        return printer.doprint(expr)

    def __repr__(self):
        """Used to hash in OpenClKernelAutotuner.autotuner_config_key()"""
        # platform and device are None for deviceless typegens: print None
        # instead of raising AttributeError.
        platform_name = None if self.platform is None else self.platform.name
        device_name = None if self.device is None else self.device.name
        return "{}_{}_{}_{}_{}_{}".format(
            platform_name,
            device_name,
            self.fbtype,
            self.float_dump_mode,
            self.use_short_circuit_ops,
            self.unroll_loops,
        )

    def get_precision_opts(self):
        """
        Check if device is capable to work with given precision
        and return build options considering this precision.

        Returns
        -------
        list of str
            ["-cl-single-precision-constant"] for fbtype "float", [] otherwise.

        Raises
        ------
        ValueError
            If fbtype is "half" (resp. "double") and the device reports a
            non-positive half_fp_config (resp. double_fp_config).
        """
        opts = []
        if self.fbtype == "half":
            if self.device.half_fp_config <= 0:
                raise ValueError("Half precision is not supported on device.")
        elif self.fbtype == "float":
            # Make the compiler treat floating point constants as single precision.
            opts.append("-cl-single-precision-constant")
        elif self.fbtype == "double":
            if self.device.double_fp_config <= 0:
                raise ValueError("Double Precision is not supported on device")
        return opts